import numpy as np
import torch

from src.envs.base_environment import ContinuousEnvironment

def angle_normalize(x):
    """Wrap angle(s) ``x`` into [-pi, pi) (up to floating-point rounding).

    Only ``+``, ``-`` and ``%`` are used, so this works element-wise on Python
    floats, NumPy arrays and torch tensors alike; in all three, ``%`` returns a
    result with the sign of the divisor, so the wrapped value is never below -pi.
    """
    full_turn = 2 * np.pi
    shifted = x + np.pi
    wrapped = shifted % full_turn
    return wrapped - np.pi

class PendulumEnvironment(ContinuousEnvironment):
    """
    ### Description

    The inverted pendulum swingup problem is based on the classic problem in control theory.
    The system consists of a pendulum attached at one end to a fixed point, and the other end being free.
    The pendulum starts in a random position and the goal is to apply impulses on the free end to swing it
    into an upright position. The reward is based on the angle of the pendulum and its angular velocity after
    applying a series of impulses at regular intervals. Higher rewards are given for the pendulum being upright
    with zero velocity.

    ### Dynamics

    Each `step` applies an instantaneous change of angular velocity of `3 * impulse / (mass * length**2)`,
    advances the angle by one `timestep` with that velocity, and then lets the pendulum swing freely
    (semi-implicit Euler with angular acceleration `3 * gravity / (2 * length) * sin(theta)`) for the rest of
    `interval_between_steps`. The angular velocity is clipped to `[-max_speed, max_speed]` after every update.
    `backward_step` runs the same sequence in reverse.

    ### Action Space

    | Num | Action  | Min               | Max              |
    |-----|---------|-------------------|------------------|
    | 0   | Impulse | -self.max_impulse | self.max_impulse |

    The bounds apply to the means of the policy's mixture components (see `postprocess_params`).

    ### Observation Space

    | Num | Observation      | Min             | Max            |
    |-----|------------------|-----------------|----------------|
    | 0   | Theta            | -pi             | pi             |
    | 1   | Angular Velocity | -self.max_speed | self.max_speed |
    | 2   | Step index       | -               | -              |

    The step index is incremented by `step` and decremented by `backward_step`.

    ### Rewards

    The log-reward is defined as:

    log r = -(th^2 + 0.1 * thdot^2)

    ### Policy Parameterisation

    The policy over the (one-dimensional) impulse is a Gaussian mixture with `mixture_dim` components.
    The network outputs `3 * mixture_dim` values: the component means, the component standard deviations
    and the mixture logits.

    ### Arguments

    All of the following are read from `config["env"]` and are required:

    - timestep (`self.dt`): integration time step
    - gravity (`self.g`): acceleration of gravity
    - mass (`self.m`): mass of the pendulum rod
    - length (`self.l`): length of the pendulum rod
    - interval_between_steps: time interval between which impulses are applied (must exceed 2 * timestep)
    - max_speed: maximum angular velocity of the pendulum
    - max_impulse: bound on the mean impulse of each policy component (and the initial-action search range)
    - max_impulse_std: maximum standard deviation of each Gaussian component of the impulse policy
    - min_impulse_std: minimum standard deviation of each Gaussian component of the impulse policy
    - mixture_dim: number of components in the Gaussian mixture model in the parameterisation of the policy

    All units are in SI units.
    """

    def __init__(self, config):
        """Build the environment from ``config``; see the class docstring for the required keys."""
        self._init_required_params(config)
        lower_bound = [-np.pi, -self.max_speed]
        upper_bound = [np.pi, self.max_speed]
        super().__init__(config,
                         dim = 2,
                         feature_dim = 3,
                         angle_dim = [True, False],
                         action_dim =  1,
                         lower_bound= lower_bound,
                         upper_bound= upper_bound,
                         mixture_dim=config["env"]["mixture_dim"],
                         output_dim=config["env"]["mixture_dim"]*3)

    def _init_required_params(self, config):
        """Read and validate the environment-specific entries of ``config["env"]``.

        Raises:
            AssertionError: if any required key is missing, or if
                ``interval_between_steps <= 2 * timestep``.
        """
        required_params = ["max_speed", "max_impulse", "max_impulse_std", "min_impulse_std", "timestep",
                           "gravity", "mass", "length", "interval_between_steps"]
        missing = [param for param in required_params if param not in config["env"]]
        assert not missing, f"Missing required parameters: {missing}"

        # The values in the trailing comments are typical settings; none of them is applied as a default.
        self.max_speed = config["env"]["max_speed"]                           # e.g. 8
        self.max_impulse = config["env"]["max_impulse"]                       # e.g. 2
        self.max_impulse_std = config["env"]["max_impulse_std"]               # e.g. 0.5
        self.min_impulse_std = config["env"]["min_impulse_std"]               # e.g. 0.1
        self.dt = config["env"]["timestep"]                                   # e.g. 0.05
        self.g = config["env"]["gravity"]                                     # e.g. 10.0
        self.m = config["env"]["mass"]                                        # e.g. 1.0
        self.l = config["env"]["length"]                                      # e.g. 1.0
        self.interval_between_steps = config["env"]["interval_between_steps"] # e.g. 0.5
        # Time interval between which impulses are applied to the pendulum.
        assert self.interval_between_steps > 2 * self.dt, "Interval between steps should be greater than 2 * timestep"

    def log_reward(self, x):
        """Return the log-reward ``-(theta**2 + 0.1 * theta_dot**2)`` of the state(s) ``x``."""
        return -x[..., 0] ** 2 - 0.1 * x[..., 1] ** 2

    def featurisation(self, states):
        """Map states ``(theta, theta_dot, step_index)`` to network features.

        The angle is embedded as ``(cos theta, sin theta)`` so the features are continuous across the
        wrap-around at +/- pi. The output has ``feature_dim + 1`` columns,
        ``[cos theta, sin theta, theta_dot, step_index]``, and lives on the same device as ``states``.
        """
        # Allocate on the input's device so the features can be fed straight to a model on that device.
        featurised_states = torch.zeros(states.shape[0], self.feature_dim + 1, device=states.device)
        featurised_states[:, 0] = torch.cos(states[:, 0])
        featurised_states[:, 1] = torch.sin(states[:, 0])
        featurised_states[:, 2] = states[:, 1]
        featurised_states[:, 3] = states[:, 2]

        return featurised_states

    def step(self, x, action):
        """Advance a batch of states by one impulse-plus-free-swing interval.

        Args:
            x: tensor whose columns are ``(theta, theta_dot, step_index)``.
            action: rotational impulse to impart on the pendulum (units N m s), one value per state;
                presumably of shape ``(batch, 1)`` — it is squeezed before use.

        Returns:
            A new tensor like ``x`` holding the wrapped angle, the new angular velocity and the step
            index incremented by one.
        """
        th, thdot, idx = x[:, 0], x[:, 1], x[:, 2]

        # Impulse: instantaneous change of angular velocity, then one timestep of motion.
        newthdot = thdot + 3 * action.squeeze() / (self.m * self.l**2)
        newthdot = torch.clamp(newthdot, -self.max_speed, self.max_speed)
        newth = th + newthdot * self.dt

        # Free swing for the rest of the interval.
        newth, newthdot = self._forward_freefall(newth, newthdot)

        new_x = torch.zeros_like(x)
        new_x[:, 0] = angle_normalize(newth)
        new_x[:, 1] = newthdot
        new_x[:, 2] = idx + 1

        return new_x

    def backward_step(self, x, action):
        """Undo `step`: return the state from which ``action`` leads to ``x``.

        The operations of `step` are reversed in the opposite order: first the free swing is rewound,
        then the timestep of motion that followed the impulse, and finally the impulse itself. Apart from
        velocity clipping (which is not invertible), ``backward_step(step(x, a), a)`` recovers ``x`` up to
        floating-point error and the 2*pi wrap of the angle.
        """
        th, thdot, idx = x[:, 0], x[:, 1], x[:, 2]

        # Rewind the free swing to the state right after the impulse timestep.
        th, thdot = self._backward_freefall(th, thdot)

        # `step` moved the angle with the post-impulse velocity, so undo that before removing the impulse.
        prevth = th - thdot * self.dt
        prevthdot = thdot - 3 * action.squeeze() / (self.m * self.l**2)
        prevthdot = torch.clamp(prevthdot, -self.max_speed, self.max_speed)

        new_x = torch.zeros_like(x)
        new_x[:, 0] = angle_normalize(prevth)
        new_x[:, 1] = prevthdot
        new_x[:, 2] = idx - 1

        return new_x

    def compute_initial_action(self, first_state):
        """Find, by bisection, the impulse that swings the pendulum from the initial angle to ``first_state``.

        For each row, the impulse is searched in ``[0, max_impulse]`` when
        ``sin(theta - init_theta) > 0`` and in ``[-max_impulse, 0]`` otherwise. Each trial is simulated
        with the same sequence as `step` (impulse, one timestep of motion, free swing), starting from
        ``self.init_value[0]``.

        Returns:
            Tensor of shape ``(batch, 1)`` on ``self.device``.
        """
        th = first_state[:, 0]
        # NOTE(review): the search starts from zero angular velocity; only the initial angle is used.
        init_th = torch.as_tensor(self.init_value[0], dtype=th.dtype, device=th.device)
        velocity_per_impulse = 3 / (self.m * self.l**2)

        direction_positive = torch.sin(th - init_th) > 0

        # Initialise the bisection bracket on the side of zero that points towards the target.
        zeros = torch.zeros_like(th)
        low = torch.where(direction_positive, zeros, -self.max_impulse * torch.ones_like(th))
        high = torch.where(direction_positive, self.max_impulse * torch.ones_like(th), zeros)

        for _ in range(100):
            trial_impulse = (low + high) / 2
            trial_velocity = torch.clamp(velocity_per_impulse * trial_impulse, -self.max_speed, self.max_speed)
            # Mirror `step`: one timestep of motion with the post-impulse velocity, then the free swing.
            newth, _ = self._forward_freefall(init_th + trial_velocity * self.dt, trial_velocity)

            error = torch.sin(newth - th)
            # "Overshoot" means ending up past the target in the search direction.
            overshoot = torch.where(direction_positive, error > 0, error < 0)
            # Positive search: overshoot lowers `high`, otherwise `low` rises.
            # Negative search: overshoot raises `low` towards zero, otherwise `high` falls.
            move_high = overshoot == direction_positive
            low = torch.where(move_high, low, trial_impulse)
            high = torch.where(move_high, trial_impulse, high)

            # Compare wrapped angles: `newth` is not normalised, `th` is.
            if torch.max(torch.abs(angle_normalize(newth - th))) < 1e-3:
                break

        action = trial_impulse.to(self.device)

        return action.unsqueeze(-1)

    def postprocess_params(self, params):
        """Split raw network outputs into constrained Gaussian-mixture parameters.

        ``params`` is split column-wise into three blocks of ``mixture_dim`` columns
        (means, standard deviations, mixture logits).

        Returns:
            Dict with ``"mus"`` in ``(-max_impulse, max_impulse)``, ``"stds"`` in
            ``(min_impulse_std, max_impulse_std)`` and ``"weights"`` softmax-normalised over components.
        """
        mu_params = params[:, :self.mixture_dim]
        std_params = params[:, self.mixture_dim: 2 * self.mixture_dim]
        weight_params = params[:, 2 * self.mixture_dim:]
        # Restrict the component means of the impulse to (-max_impulse, max_impulse).
        mus = 2 * self.max_impulse * torch.sigmoid(mu_params) - self.max_impulse
        # Restrict the policy std to be within the specified bounds.
        stds = torch.sigmoid(std_params) * (self.max_impulse_std - self.min_impulse_std) + self.min_impulse_std
        # Normalise the weights of the Gaussian mixture model.
        weights = torch.softmax(weight_params, dim=1)
        param_dict = {"mus": mus, "stds": stds, "weights": weights}
        return param_dict

    def add_noise(self, param_dict, off_policy_noise):
        """Widen the policy for exploration by adding ``off_policy_noise`` to every component std.

        The stds tensor is replaced rather than modified in place, so a tensor that autograd (or the
        caller) still references is not mutated. ``param_dict`` itself is updated and returned.
        """
        param_dict["stds"] = param_dict["stds"] + off_policy_noise

        return param_dict

    def _freefall_steps(self):
        """Number of free-swing integration steps that follow the impulse timestep in one interval."""
        # The small tolerance keeps float error from truncating e.g. (0.3 - 0.1) / 0.1 = 1.999... down to 1.
        return int((self.interval_between_steps - self.dt) / self.dt + 1e-9)

    def _forward_freefall(self, th, thdot):
        """Integrate the free swing (semi-implicit Euler) for the rest of the interval.

        Each step updates the velocity from the current angle, clips it, then moves the angle with
        the new velocity.
        """
        angular_accel = 3 * self.g / (2 * self.l)
        for _ in range(self._freefall_steps()):
            thdot = thdot + (angular_accel * torch.sin(th)) * self.dt
            thdot = torch.clamp(thdot, -self.max_speed, self.max_speed)
            th = th + thdot * self.dt

        return th, thdot

    def _backward_freefall(self, th, thdot):
        """Invert `_forward_freefall` step by step (exact apart from velocity clipping).

        The forward update is ``thdot' = thdot + a*sin(th)*dt; th' = th + thdot'*dt``, so each inverse
        step first recovers ``th`` with the current velocity and then removes the acceleration
        evaluated at that recovered angle.
        """
        angular_accel = 3 * self.g / (2 * self.l)
        for _ in range(self._freefall_steps()):
            th = th - thdot * self.dt
            thdot = thdot - (angular_accel * torch.sin(th)) * self.dt
            thdot = torch.clamp(thdot, -self.max_speed, self.max_speed)

        return th, thdot

    def _init_policy_dist(self, param_dict):
        """Build the Gaussian-mixture distribution over the one-dimensional impulse.

        ``mus``/``stds``/``weights`` carry one entry per mixture component in their last dimension.
        Each component is a Normal with event shape ``(1,)``, so samples have a trailing size-1 dimension.
        """
        mus, stds, weights = param_dict["mus"], param_dict["stds"], param_dict["weights"]
        mix = torch.distributions.Categorical(weights)
        comp = torch.distributions.Independent(torch.distributions.Normal(mus.unsqueeze(-1), stds.unsqueeze(-1)), 1)

        return torch.distributions.MixtureSameFamily(mix, comp)
